{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dec19f84",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Log the full interpreter version string so results can be tied to a Python build.\n",
    "import sys\n",
    "\n",
    "print(sys.version)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5c6c485c",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.utils.data as data\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.datasets import CIFAR10\n",
    "from datasets import CIFAR10_truncated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4060bb94",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def load_cifar10_data(datadir):\n",
    "    \"\"\"Load the CIFAR-10 train and test splits as in-memory arrays.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    datadir : str\n",
    "        Root directory handed to ``CIFAR10_truncated`` (called with\n",
    "        ``download=True``; presumably fetches the files if missing).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        ``(X_train, y_train, X_test, y_test)`` read from each dataset's\n",
    "        ``.data`` / ``.target`` attributes.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    AssertionError\n",
    "        If a split has a different number of images and labels.\n",
    "    \"\"\"\n",
    "    # ``transform`` is handed to the dataset objects (presumably applied per item\n",
    "    # in ``__getitem__`` -- TODO confirm in datasets.py). The ``.data`` arrays\n",
    "    # returned here bypass it: earlier runs returned raw uint8 NHWC images.\n",
    "    transform = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "    cifar10_train_ds = CIFAR10_truncated(datadir, train=True, download=True, transform=transform)\n",
    "    cifar10_test_ds = CIFAR10_truncated(datadir, train=False, download=True, transform=transform)\n",
    "\n",
    "    X_train, y_train = cifar10_train_ds.data, cifar10_train_ds.target\n",
    "    X_test, y_test = cifar10_test_ds.data, cifar10_test_ds.target\n",
    "\n",
    "    # Cheap invariants instead of printing shapes from inside the loader.\n",
    "    assert len(X_train) == len(y_train), \"train images/labels length mismatch\"\n",
    "    assert len(X_test) == len(y_test), \"test images/labels length mismatch\"\n",
    "\n",
    "    return X_train, y_train, X_test, y_test\n",
    "\n",
    "\n",
    "DATA_DIR = \"./data/\"\n",
    "\n",
    "X_train, y_train, X_test, y_test = load_cifar10_data(DATA_DIR)\n",
    "\n",
    "# Show a compact summary; echoing the returned tuple would dump every image.\n",
    "{name: np.shape(arr) for name, arr in [\n",
    "    (\"X_train\", X_train), (\"y_train\", y_train),\n",
    "    (\"X_test\", X_test), (\"y_test\", y_test),\n",
    "]}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
